Multi-label vs. multi-class classification
An experiment comparing loss functions on the same dataset.
- Tank classifier models:
- Building a model with a multi-class loss
- Building a model with a multi-label loss
- The same multi-label model, but with an explicitly assigned loss function
- Additional validation
tank_types = ('merkava mk4', 'M1 Abrams', 'water')
path = Path('tanks')

# Download up to 150 search results per class into tanks/<class name>/.
# The search query appends " tank" to the class name (so 'water' searches
# "water tank"). The whole step is skipped if tanks/ already exists.
if not path.exists():
    path.mkdir()
    for tank_type in tank_types:
        class_dir = path/tank_type
        class_dir.mkdir(exist_ok=True)
        download_images(class_dir, urls=search_images_ddg(f'{tank_type} tank', max_images=150))

# Delete every downloaded file that verify_images reports as unreadable.
for bad_file in verify_images(get_image_files(path)):
    bad_file.unlink()
# Single-label (multi-class) pipeline: each image's label is the name of its
# parent folder.
tanks = DataBlock(
    blocks=(ImageBlock, CategoryBlock),
    get_items=get_image_files,
    get_y=parent_label,
    splitter=RandomSplitter(valid_pct=0.2, seed=42),  # fixed seed -> reproducible 80/20 split
    item_tfms=RandomResizedCrop(224, min_scale=0.5),
    batch_tfms=aug_transforms(),
)
dls = tanks.dataloaders(path)
dls.valid.show_batch(nrows=1, max_n=5)
# Single-label classifier: resnet18 backbone, tracked with error_rate,
# fine-tuned for 4 epochs.
learn = cnn_learner(dls, resnet18, metrics=error_rate)
learn.fine_tune(4)

# The interpretation object must be built from this learner before plotting.
interp = ClassificationInterpretation.from_learner(learn)
# plot_top_losses_fix is defined in the "Fixing the top lose" section further
# down; run that cell before this line.
plot_top_losses_fix(interp, 10, nrows=10)
def parentlabel(x):
    """Return the item's parent-folder name as a one-element list of labels.

    The multi-label DataBlock (MultiCategoryBlock) takes a list of labels per
    item, so the single folder label is wrapped in a list instead of being
    returned as a bare string.
    """
    folder = x.parent
    return [folder.name]
# Multi-label variant of the same pipeline: same items, split and augmentation,
# but MultiCategoryBlock with list-valued labels produced by parentlabel.
# Alternative tried in this experiment: MultiCategoryBlock(add_na=True)
tanks = DataBlock(
    blocks=(ImageBlock, MultiCategoryBlock),
    get_items=get_image_files,
    get_y=parentlabel,
    splitter=RandomSplitter(valid_pct=0.2, seed=42),
    item_tfms=RandomResizedCrop(224, min_scale=0.5),
    batch_tfms=aug_transforms(),
)
dls = tanks.dataloaders(path)
dls.valid.show_batch(ncols=5, nrows=1)
# Multi-label classifier. accuracy_multi is used with its defaults
# (thresh=0.5, sigmoid=True); a stricter variant would be
# partial(accuracy_multi, thresh=0.95).
multi_learner = cnn_learner(
    dls,
    resnet18,
    metrics=accuracy_multi,
)
multi_learner.fine_tune(4)
Fixing the top-losses plot
def plot_top_losses_fix(interp, k, largest=True, **kwargs):
    """Plot the k samples with the largest (or smallest) loss from a fastai Interpretation.

    Workaround re-implementation of the top-losses plot. The commented-out
    line inside (which refers to `self.preds`, i.e. the original method form)
    is replaced by indexing `interp.preds` directly.

    Args:
        interp: fastai Interpretation object (e.g. ClassificationInterpretation)
            exposing `top_losses`, `inputs`, `targs`, `decoded`, `preds`, `dl`.
        k: number of samples to plot.
        largest: forwarded to `interp.top_losses`.
        **kwargs: forwarded to `plot_top_losses` (e.g. `nrows`).

    Side effects: if `interp.inputs` is not a tuple it is replaced by a
    1-tuple wrapping it. Nothing is plotted when `its` comes back as None.
    """
    losses,idx = interp.top_losses(k, largest)
    # Normalise inputs to a tuple so single- and multi-input models share one path.
    if not isinstance(interp.inputs, tuple): interp.inputs = (interp.inputs,)
    # Tensor inputs are indexed directly; otherwise a batch is rebuilt from the selected items.
    if isinstance(interp.inputs[0], Tensor): inps = tuple(o[idx] for o in interp.inputs)
    else: inps = interp.dl.create_batch(interp.dl.before_batch([tuple(o[i] for o in interp.inputs) for i in idx]))
    # Selected inputs + targets, prepared for display.
    b = inps + tuple(o[idx] for o in (interp.targs if is_listy(interp.targs) else (interp.targs,)))
    x,y,its = interp.dl._pre_show_batch(b, max_n=k)
    # Same inputs + decoded predictions; only `outs` is used by the live code below.
    b_out = inps + tuple(o[idx] for o in (interp.decoded if is_listy(interp.decoded) else (interp.decoded,)))
    x1,y1,outs = interp.dl._pre_show_batch(b_out, max_n=k)
    if its is not None:
        # Method-form original, kept for reference:
        #plot_top_losses(x, y, its, outs.itemgot(slice(len(inps), None)), L(self.preds).itemgot(idx), losses, **kwargs)
        plot_top_losses(x, y, its, outs.itemgot(slice(len(inps), None)), interp.preds[idx], losses, **kwargs)
    #TODO: figure out if this is needed
    #its None means that a batch knows how to show itself as a whole, so we pass x, x1
    #else: show_results(x, x1, its, ctxs=ctxs, max_n=max_n, **kwargs)
# Plot the 10 highest-loss samples of the multi-label learner.
intrpt = ClassificationInterpretation.from_learner(multi_learner)
plot_top_losses_fix(intrpt, k=10, nrows=10)
# Re-score with a stricter 0.8 decision threshold for accuracy_multi.
strict_metrics = [partial(accuracy_multi, thresh=0.8)]
multi_learner.metrics = strict_metrics
multi_learner.validate()
# Sweep the decision threshold from 0.05 to 0.95 (29 steps) and plot accuracy.
# sigmoid=False because get_preds presumably already applies the loss's
# sigmoid activation — confirm for the fastai version in use.
preds, targs = multi_learner.get_preds()
xs = torch.linspace(0.05, 0.95, 29)
accs = []
for thresh in xs:
    accs.append(accuracy_multi(preds, targs, thresh=thresh, sigmoid=False))
plt.plot(xs, accs);
# Try out the decision threshold stored on the loss function.
# Presumably this is fastai's BCEWithLogitsLossFlat, auto-selected for
# MultiCategoryBlock — confirm.
multi_learner.loss_func.thresh=0.2
multi_learner.loss_func.thresh  # bare expression: shows the current value in the notebook
# Settle on 0.8 (the same threshold used for the metric above) before exporting.
multi_learner.loss_func.thresh=0.8
# NOTE(review): export writes relative to multi_learner.path, which is not
# necessarily `path` — check the file location before calling load_learner.
multi_learner.export('mlmodel.pkl')
# Same multi-label setup, but with a plain PyTorch BCE-with-logits loss assigned explicitly.
# NOTE(review): nn.BCEWithLogitsLoss has no fastai `activation`/`decodes`, so
# get_preds/predict presumably return raw logits and undecoded labels — confirm
# before comparing predictions against multi_learner.
ls=nn.BCEWithLogitsLoss()
multi_learner2 = cnn_learner(dls, resnet18, loss_func=ls, metrics=accuracy_multi)
multi_learner2.fine_tune(4)
multi_learner2.get_preds(with_decoded=False)  # shows (preds, targs) in the notebook
Making a multi-label classifier on the same data with our own loss function
# Multi-label learner using the functional BCE-with-logits loss directly
# (alternative: loss_func=BCEWithLogitsLossFlat(thresh=0.7)).
# Note: `lr` is only constructed here, not trained.
lr = cnn_learner(
    dls,
    resnet18,
    loss_func=F.binary_cross_entropy_with_logits,
    metrics=partial(accuracy_multi, thresh=0.2),
)
To use our model in an application, we can simply treat the predict method as a regular function.
# Load the exported multi-label model for inference.
# multi_learner.export('mlmodel.pkl') writes relative to the learner's own path.
# DataBlock.dataloaders(path) only uses `path` as the item source; the
# dataloaders'/learner's path defaults to '.'. So the file is ./mlmodel.pkl,
# not tanks/export.pkl. `path` is deliberately left untouched so the earlier
# cells still work.
learn_inf = load_learner(Path()/'mlmodel.pkl')
# Minimal upload-and-classify UI around learn_inf.
btn_upload = widgets.FileUpload()
out_pl = widgets.Output()
lbl_pred = widgets.Label()
btn_run = widgets.Button(description='Classify')

def on_click_classify(change):
    """Classify the most recently uploaded image and show the prediction.

    Works for both kinds of learner:
    - Single-label: predict() yields a 0-dim class index, so one label and one
      probability are shown.
    - Multi-label: predict() yields a per-class mask. `probs[pred_idx]` is then
      a 1-D tensor that cannot be formatted with ':.04f', so each predicted
      label is shown with its own probability.
    """
    # NOTE(review): `.data` is the ipywidgets 7 FileUpload API — confirm version.
    img = PILImage.create(btn_upload.data[-1])
    out_pl.clear_output()
    with out_pl:
        display(img.to_thumb(128,128))
    pred,pred_idx,probs = learn_inf.predict(img)
    if pred_idx.dim() == 0:
        lbl_pred.value = f'Prediction: {pred}; Probability: {probs[pred_idx]:.04f}'
    else:
        # Mask order matches vocab order, as does the order of decoded labels.
        label_probs = ', '.join(f'{lbl} ({p:.04f})' for lbl, p in zip(pred, probs[pred_idx].tolist()))
        lbl_pred.value = f'Prediction: {label_probs or "none"}'

btn_run.on_click(on_click_classify)
VBox([widgets.Label('Select your tank!'),
      btn_upload, btn_run, out_pl, lbl_pred])
# Spot-check predictions on individually uploaded images (paths under /content,
# presumably a Colab session).
# NOTE(review): multi_learner2 uses plain nn.BCEWithLogitsLoss, so its decoded
# labels may not be meaningful — compare with multi_learner's output.
multi_learner2.predict('/content/DUDU.jpg')
multi_learner2.predict('/content/TetroWtank.jpg')
# Judging by the file names, these images show several tanks or both tank types.
multi_learner.predict('/content/manytanks.jpg')
multi_learner.predict('/content/merk_abrms2.jpeg')
multi_learner.predict('/content/merkav_abrams.jpg')
multi_learner.predict('/content/mek_abs_3.jpg')
#img.show()
#learn.predict(img)[0]